import json
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Optional

# from colorama import Fore, Style, init
from rich.console import Console
from rich.live import Live
from rich.markup import escape

from letta.interface import CLIInterface
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse, ChatCompletionResponse

# init(autoreset=True)

# Default value of the `debug` parameter accepted by the CLI message handlers below.
# DEBUG = True  # puts full message outputs in the terminal
DEBUG = False  # only dumps important messages in the terminal

# NOTE(review): STRIP_UI is not referenced anywhere in the visible part of this
# module -- presumably it is read elsewhere; confirm before relying on or removing it.
STRIP_UI = False


class AgentChunkStreamingInterface(ABC):
    """Observer-style contract for front-ends that consume agent output one chunk at a time.

    Event hooks receive the scoped text in ``msg``; the optional ``msg_obj`` is the
    Message it came from, for implementations that want additional metadata.
    Every member is abstract: concrete interfaces must implement all of them.
    """

    # --- stream lifecycle -------------------------------------------------

    @abstractmethod
    def stream_start(self):
        """Perform whatever setup is needed before the first chunk arrives."""
        raise NotImplementedError

    @abstractmethod
    def process_chunk(
        self,
        chunk: ChatCompletionChunkResponse,
        message_id: str,
        message_date: datetime,
        expect_reasoning_content: bool = False,
        name: Optional[str] = None,
        message_index: int = 0,
        prev_message_type: Optional[str] = None,
    ):
        """Consume one streamed chunk produced by an OpenAI-compatible server.

        ``chunk`` is the streamed payload; the remaining arguments are metadata
        supplied by the caller alongside it (their exact meaning is defined by
        the caller and the concrete implementation).
        """
        raise NotImplementedError

    @abstractmethod
    def stream_end(self):
        """Perform whatever cleanup is needed once the stream has finished."""
        raise NotImplementedError

    # --- message events ---------------------------------------------------

    @abstractmethod
    def user_message(
        self,
        msg: str,
        msg_obj: Optional[Message] = None,
    ):
        """Called when Letta receives a message from the user."""
        raise NotImplementedError

    @abstractmethod
    def internal_monologue(
        self,
        msg: str,
        msg_obj: Optional[Message] = None,
        chunk_index: Optional[int] = None,
    ):
        """Called when Letta produces internal monologue."""
        raise NotImplementedError

    @abstractmethod
    def assistant_message(
        self,
        msg: str,
        msg_obj: Optional[Message] = None,
    ):
        """Called when Letta uses send_message."""
        raise NotImplementedError

    @abstractmethod
    def function_message(
        self,
        msg: str,
        msg_obj: Optional[Message] = None,
        chunk_index: Optional[int] = None,
    ):
        """Called when Letta invokes a function."""
        raise NotImplementedError


class StreamingCLIInterface(AgentChunkStreamingInterface):
    """Version of the CLI interface that attaches to a stream generator and prints along the way.

    When a chunk is received, we write the delta straight to the terminal. If the
    buffer type has changed, we first write out a newline so the new buffer starts
    on its own line.

    The two buffer types are:
      (1) content (inner thoughts)
      (2) tool_calls (function calling)

    NOTE: this assumes that the deltas received in the chunks are in-order, e.g.
    that once 'content' deltas stop streaming, they won't be received again.

    An alternative implementation could instead maintain the partial message state, and on each
    process chunk (1) update the partial message state, (2) refresh/rewrite the state to the screen.

    Everything that is not part of the streaming path is delegated to the shared
    non-streaming ``CLIInterface`` instance, by calling the method of the same name.
    """

    # CLIInterface is static/stateless, so one shared instance serves every delegate
    nonstreaming_interface = CLIInterface()

    def __init__(self):
        """Initialise the state used to track what is currently being written to the terminal."""

        # Buffer the current terminal line belongs to: None, "content" or "tool_calls"
        self.streaming_buffer_type = None
        # True once "name(" has been printed and the matching ")" is still owed
        self.tool_call_open_paren = False

    def _flush(self):
        """No-op hook: every print in this class already flushes immediately."""
        pass

    def _close_tool_call_paren(self):
        """Print the ")" owed to a previously printed "name(", if there is one."""
        if self.tool_call_open_paren:
            print(")", end="", flush=True)
            self.tool_call_open_paren = False

    def process_chunk(
        self,
        chunk: ChatCompletionChunkResponse,
        message_id: str,
        message_date: datetime,
        expect_reasoning_content: bool = False,
        name: Optional[str] = None,
        message_index: int = 0,
        prev_message_type: Optional[str] = None,
    ):
        """Print the delta carried by a single streamed chunk.

        Only ``chunk`` is read; the other parameters are accepted to match the
        AgentChunkStreamingInterface signature and are ignored here.

        Raises:
            AssertionError: if the chunk does not hold exactly one choice, if the
                first delta of a stream carries both content and tool calls, or if
                a delta carries a number of tool calls other than one.
        """
        assert len(chunk.choices) == 1, chunk

        message_delta = chunk.choices[0].delta

        # Starting a new buffer line: print the label for whichever buffer shows up first
        if not self.streaming_buffer_type:
            assert not (
                message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)
            ), f"Error: got both content and tool_calls in message stream\n{message_delta}"

            if message_delta.content is not None:
                # Write out the prefix for inner thoughts
                print("Inner thoughts: ", end="", flush=True)
            elif message_delta.tool_calls is not None:
                assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
                # Write out the prefix for function calling
                print("Calling function: ", end="", flush=True)

        # Write out the delta
        if message_delta.content is not None:
            if self.streaming_buffer_type and self.streaming_buffer_type != "content":
                # Leaving a tool-call line: finish it properly before moving on
                self._close_tool_call_paren()
                print()
            self.streaming_buffer_type = "content"

            # Simple, just write out to the buffer
            print(message_delta.content, end="", flush=True)

        elif message_delta.tool_calls is not None:
            if self.streaming_buffer_type and self.streaming_buffer_type != "tool_calls":
                print()
            self.streaming_buffer_type = "tool_calls"

            assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
            function_call = message_delta.tool_calls[0].function

            # Slightly more complex - want to write parameters in a certain way (paren-style)
            # function_name(function_args)
            if function_call and function_call.name:
                print(f"{function_call.name}(", end="", flush=True)
                # Remember the open paren so it is closed exactly once, and only if it was opened
                self.tool_call_open_paren = True
            if function_call and function_call.arguments:
                print(function_call.arguments, end="", flush=True)

    def stream_start(self):
        """Reset the printing state before a new stream begins."""
        # should be handled by stream_end(), but just in case
        self.streaming_buffer_type = None
        self.tool_call_open_paren = False

    def stream_end(self):
        """Finish the current terminal line (closing any open paren) and reset the state."""
        if self.streaming_buffer_type is not None:
            self._close_tool_call_paren()

            print()  # newline to move the cursor
            self.streaming_buffer_type = None  # reset buffer tracker

    @staticmethod
    def important_message(msg: str):
        """Delegate to the non-streaming interface's ``important_message``."""
        StreamingCLIInterface.nonstreaming_interface.important_message(msg)

    @staticmethod
    def warning_message(msg: str):
        """Delegate to the non-streaming interface's ``warning_message``."""
        StreamingCLIInterface.nonstreaming_interface.warning_message(msg)

    @staticmethod
    def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
        """Delegate to the non-streaming interface's ``internal_monologue`` (``chunk_index`` is not forwarded)."""
        StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)

    @staticmethod
    def assistant_message(msg: str, msg_obj: Optional[Message] = None):
        """Delegate to the non-streaming interface's ``assistant_message``."""
        StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)

    @staticmethod
    def memory_message(msg: str, msg_obj: Optional[Message] = None):
        """Delegate to the non-streaming interface's ``memory_message``."""
        StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)

    @staticmethod
    def system_message(msg: str, msg_obj: Optional[Message] = None):
        """Delegate to the non-streaming interface's ``system_message``."""
        StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)

    @staticmethod
    def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
        """Delegate to the non-streaming interface's ``user_message``.

        ``raw``, ``dump`` and ``debug`` are accepted for signature compatibility but are not forwarded.
        """
        StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)

    @staticmethod
    def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None):
        """Delegate to the non-streaming interface's ``function_message``.

        ``debug`` and ``chunk_index`` are accepted for signature compatibility but are not forwarded.
        """
        StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)

    @staticmethod
    def print_messages(message_sequence: List[Message], dump=False):
        """Delegate to the non-streaming interface's ``print_messages``."""
        StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)

    @staticmethod
    def print_messages_simple(message_sequence: List[Message]):
        """Delegate to the non-streaming interface's ``print_messages_simple``."""
        StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)

    @staticmethod
    def print_messages_raw(message_sequence: List[Message]):
        """Delegate to the non-streaming interface's ``print_messages_raw``."""
        StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)

    @staticmethod
    def step_yield():
        """No-op: the streaming CLI has nothing to do between agent steps."""
        pass


class AgentRefreshStreamingInterface(ABC):
    """Observer-style contract for front-ends that redraw agent output on every update.

    It mirrors AgentChunkStreamingInterface, except that updates arrive through
    ``process_refresh`` as a ChatCompletionResponse rather than through
    ``process_chunk`` as a chunk, and streaming can be switched on and off.

    Event hooks receive the scoped text in ``msg``; the optional ``msg_obj`` is the
    Message it came from, for implementations that want additional metadata.
    Every member is abstract: concrete interfaces must implement all of them.
    """

    # --- stream lifecycle -------------------------------------------------

    @abstractmethod
    def toggle_streaming(
        self,
        on: bool,
    ):
        """Switch streaming on or off (off = behave like the regular CLI interface)."""
        raise NotImplementedError

    @abstractmethod
    def stream_start(self):
        """Perform whatever setup is needed before the first update arrives."""
        raise NotImplementedError

    @abstractmethod
    def process_refresh(
        self,
        response: ChatCompletionResponse,
    ):
        """Consume one streamed update, delivered as a ChatCompletionResponse, from an OpenAI-compatible server."""
        raise NotImplementedError

    @abstractmethod
    def stream_end(self):
        """Perform whatever cleanup is needed once the stream has finished."""
        raise NotImplementedError

    # --- message events ---------------------------------------------------

    @abstractmethod
    def user_message(
        self,
        msg: str,
        msg_obj: Optional[Message] = None,
    ):
        """Called when Letta receives a message from the user."""
        raise NotImplementedError

    @abstractmethod
    def internal_monologue(
        self,
        msg: str,
        msg_obj: Optional[Message] = None,
        chunk_index: Optional[int] = None,
    ):
        """Called when Letta produces internal monologue."""
        raise NotImplementedError

    @abstractmethod
    def assistant_message(
        self,
        msg: str,
        msg_obj: Optional[Message] = None,
    ):
        """Called when Letta uses send_message."""
        raise NotImplementedError

    @abstractmethod
    def function_message(
        self,
        msg: str,
        msg_obj: Optional[Message] = None,
        chunk_index: Optional[int] = None,
    ):
        """Called when Letta invokes a function."""
        raise NotImplementedError


class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
    """Version of the CLI interface that attaches to a stream generator and refreshes a render of the message at every step.

    On each ``process_refresh`` call the whole output line is rebuilt from the
    response that was passed in and redrawn through a rich ``Live`` display.

    Text that originates from the model is escaped before it is embedded in rich
    markup, so bracketed text in the output is shown literally instead of being
    parsed as a style tag.
    """

    # NOTE: the delegates below go through StreamingCLIInterface.nonstreaming_interface;
    # this attribute is kept for callers that may read it directly.
    nonstreaming_interface = CLIInterface

    def __init__(self, fancy: bool = True, separate_send_message: bool = True, disable_inner_mono_call: bool = True):
        """Initialize the streaming CLI interface state.

        Args:
            fancy: render with rich styling (italics / bold) and the CLI symbols;
                when False the output is escaped and shown as plain text.
            separate_send_message: render ``send_message`` calls as an assistant
                message inside the live display and suppress ``assistant_message``.
            disable_inner_mono_call: suppress ``internal_monologue`` (the live
                display already shows the inner thoughts).
        """
        self.console = Console()

        # The Live display is only started in stream_start()
        self.live = self._new_live()

        # Use italics / emoji?
        self.fancy = fancy

        self.streaming = True
        self.separate_send_message = separate_send_message
        self.disable_inner_mono_call = disable_inner_mono_call

    def _new_live(self):
        """Create a fresh, not-yet-started Live display bound to this console."""
        # `refresh_per_second` limits the refresh rate, avoiding excessive updates
        return Live("", console=self.console, refresh_per_second=10)

    def _styled(self, text, style: str) -> str:
        """Return ``text`` ready to be embedded in the output string.

        In fancy mode the text is escaped and wrapped in ``[style]...[/style]``.
        In plain mode it is returned untouched, because update_output() escapes
        the complete string and any tag added here would be printed literally.
        """
        text = str(text)
        if not self.fancy:
            return text
        return f"[{style}]{escape(text)}[/{style}]"

    def _placeholder(self) -> str:
        """Return the line shown while no inner thoughts have been received yet."""
        if self.fancy:
            return f"{INNER_THOUGHTS_CLI_SYMBOL} [italic]...[/italic]"
        # Same prefix process_refresh() uses for inner thoughts in plain mode
        return "[inner thoughts] ..."

    @staticmethod
    def _extract_send_message(function_args: str):
        """Pull the ``message`` argument out of (possibly incomplete) send_message JSON.

        While the arguments are still streaming they are not valid JSON yet, so on
        a parse failure fall back to stripping the expected JSON prefix by hand.
        """
        try:
            return json.loads(function_args)["message"]
        except (json.JSONDecodeError, KeyError, TypeError):
            # JSONDecodeError: partial JSON; KeyError/TypeError: valid JSON without a "message" field
            prefix = '{\n  "message": "'
            if len(function_args) < len(prefix):
                return "..."
            if function_args.startswith(prefix):
                return function_args[len(prefix) :]
            return function_args

    def toggle_streaming(self, on: bool):
        """Toggle streaming on/off; the two suppression flags always follow the streaming state."""
        self.streaming = on
        enabled = bool(on)
        self.separate_send_message = enabled
        self.disable_inner_mono_call = enabled

    def update_output(self, content: str):
        """Update the displayed output with new content.

        In fancy mode ``content`` is interpreted as rich markup; in plain mode it
        is escaped first so that it is displayed verbatim.
        """
        # We use the `Live` object's update mechanism to refresh content without clearing the console
        if not self.fancy:
            content = escape(content)
        self.live.update(self.console.render_str(content), refresh=True)

    def process_refresh(self, response: ChatCompletionResponse):
        """Process the response to rewrite the current output buffer.

        Shows the inner thoughts (message content) and, on a following line, the
        first tool call. Only the first choice and the first tool call are rendered.
        """
        if not response.choices:
            self.update_output(self._placeholder())
            return  # Early exit if there are no choices

        choice = response.choices[0]
        inner_thoughts = choice.message.content or ""
        tool_calls = choice.message.tool_calls or []

        if not inner_thoughts:
            message_string = ""
        elif self.fancy:
            message_string = f"{INNER_THOUGHTS_CLI_SYMBOL} {self._styled(inner_thoughts, 'italic')}"
        else:
            message_string = "[inner thoughts] " + inner_thoughts

        if tool_calls:
            function_call = tool_calls[0].function
            # Either field may still be empty (or missing) early in the stream
            function_name = function_call.name or ""
            function_args = function_call.arguments or ""
            if message_string:
                message_string += "\n"
            # special case here for send_message
            if self.separate_send_message and function_name == "send_message":
                message = self._extract_send_message(function_args)
                message_string += f"{ASSISTANT_MESSAGE_CLI_SYMBOL} {self._styled(message, 'bold yellow')}"
            else:
                call_repr = f"{function_name}({function_args})"
                # Arguments are arbitrary JSON; keep any bracketed text from being read as markup
                message_string += escape(call_repr) if self.fancy else call_repr

        self.update_output(message_string)

    def stream_start(self):
        """Start the live display and show the placeholder line (only when streaming is on)."""
        if self.streaming:
            print()
            self.live.start()  # Start the Live display context and keep it running
            self.update_output(self._placeholder())

    def stream_end(self):
        """Stop the live display, if it is running, and prepare a fresh one for the next stream."""
        if self.streaming and self.live.is_started:
            self.live.stop()
            print()
            self.live = self._new_live()

    @staticmethod
    def important_message(msg: str):
        """Delegate to the non-streaming interface's ``important_message``."""
        StreamingCLIInterface.nonstreaming_interface.important_message(msg)

    @staticmethod
    def warning_message(msg: str):
        """Delegate to the non-streaming interface's ``warning_message``."""
        StreamingCLIInterface.nonstreaming_interface.warning_message(msg)

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
        """Delegate to the non-streaming interface unless ``disable_inner_mono_call`` is set."""
        if self.disable_inner_mono_call:
            return
        StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Delegate to the non-streaming interface unless ``separate_send_message`` is set."""
        if self.separate_send_message:
            return
        StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)

    @staticmethod
    def memory_message(msg: str, msg_obj: Optional[Message] = None):
        """Delegate to the non-streaming interface's ``memory_message``."""
        StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)

    @staticmethod
    def system_message(msg: str, msg_obj: Optional[Message] = None):
        """Delegate to the non-streaming interface's ``system_message``."""
        StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)

    @staticmethod
    def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
        """Delegate to the non-streaming interface's ``user_message``.

        ``raw``, ``dump`` and ``debug`` are accepted for signature compatibility but are not forwarded.
        """
        StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)

    @staticmethod
    def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None):
        """Delegate to the non-streaming interface's ``function_message``.

        ``debug`` and ``chunk_index`` are accepted for signature compatibility but are not forwarded.
        """
        StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)

    @staticmethod
    def print_messages(message_sequence: List[Message], dump=False):
        """Delegate to the non-streaming interface's ``print_messages``."""
        StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)

    @staticmethod
    def print_messages_simple(message_sequence: List[Message]):
        """Delegate to the non-streaming interface's ``print_messages_simple``."""
        StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)

    @staticmethod
    def print_messages_raw(message_sequence: List[Message]):
        """Delegate to the non-streaming interface's ``print_messages_raw``."""
        StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)

    @staticmethod
    def step_yield():
        """No-op: the refreshing CLI has nothing to do between agent steps."""
        pass
